import os
from datetime import datetime

# Collect the file names of the training images and their label masks.
# Both lists are sorted because images and labels are paired by position
# when the train/val list files are written.
train_img_path = r'../XJU3_filterImg/train_img'
train_lab_path = r'../XJU3_filterImg/train_lab'
train_img_files = sorted(os.listdir(train_img_path))
train_lab_files = sorted(os.listdir(train_lab_path))

# Split ratios: the first `train_ratio` share of the sorted images is used
# for training and the remainder for validation.
# NOTE(review): val_ratio is not read anywhere; val_num is derived as the
# remainder, so keep the two ratios consistent by hand.
train_ratio, val_ratio = 0.8, 0.2
train_num = int(len(train_img_files) * train_ratio)
val_num = len(train_img_files) - train_num

# Path prefixes prepended to every entry, so that each line reads
# <Prefix><TrainPrefix|LabelPrefix><file name>.
Prefix, TrainPrefix, LabelPrefix = 'tools/XJU3_filterImg/', 'train_img/', 'train_lab/'

# Today's date as yyyymmdd, used to stamp the output file names.
date = f'{datetime.now():%Y%m%d}'
# Write the train/val list files. Each line is "<image path> <label path>",
# pairing the i-th sorted image with the i-th sorted label. The first
# `train_num` pairs go to the train list; the remaining pairs go to val.


def _pair_line(img_name, lab_name):
    """Return one list line pairing an image file with its label file."""
    return (Prefix + TrainPrefix + img_name + ' '
            + Prefix + LabelPrefix + lab_name + '\n')


# Pairing is purely positional. A count mismatch would either crash
# part-way through writing (too few labels) or silently misalign and drop
# pairs (too many labels), so refuse to write anything in that case.
if len(train_img_files) != len(train_lab_files):
    raise ValueError(
        f'image/label count mismatch: {len(train_img_files)} images vs '
        f'{len(train_lab_files)} labels; cannot pair them by position'
    )

# Create the output directory so open() does not fail on a fresh checkout.
txt_dir = r'./txt/'
os.makedirs(txt_dir, exist_ok=True)

# `with` guarantees the file is flushed and closed even if a write fails.
# An explicit encoding avoids depending on the platform's locale.
train_txt_path = txt_dir + date + 'train.txt'
with open(train_txt_path, 'w', encoding='utf-8') as train_txt:
    train_txt.writelines(
        _pair_line(img, lab)
        for img, lab in zip(train_img_files[:train_num],
                            train_lab_files[:train_num])
    )
print('train.txt文件创建完成')

val_txt_path = txt_dir + date + 'val.txt'
with open(val_txt_path, 'w', encoding='utf-8') as val_txt:
    val_txt.writelines(
        _pair_line(img, lab)
        for img, lab in zip(train_img_files[train_num:],
                            train_lab_files[train_num:])
    )
print('val.txt文件创建完成')